{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Encode Datasets\n",
    "\n",
    "This notebook takes as input a trained encoder, target set, and query set, and saves encoded targets & queries, as well as the query-target distance matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import pandas as pd\n",
    "\n",
    "import primo.models\n",
    "import primo.datasets\n",
    "\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the encoder from the saved model file. `encoder` is reused by the cells\n",
    "# below to encode both the query features and the target features.\n",
    "# NOTE(review): presumably this loads the trained weights from the .h5 file --\n",
    "# confirm against primo.models.Encoder.\n",
    "encoder = primo.models.Encoder('/tf/primo/data/models/encoder-model.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encode Queries\n",
    "This code loads the query features, encodes them to DNA sequences, and saves the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query feature table, read from HDF5. Its index is reused as the index of the\n",
    "# saved query sequences, and its rows are iterated as (query_id, feature row)\n",
    "# pairs when the query-target distances are computed further down.\n",
    "query_features = pd.read_hdf('/tf/primo/data/queries/features.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the query features into DNA sequences.\n",
    "query_seqs = encoder.encode_feature_seqs(query_features)\n",
    "\n",
    "# Store the sequences under the same index as the query features.\n",
    "query_seq_frame = pd.DataFrame(\n",
    "    query_seqs, index=query_features.index, columns=['FeatureSequence']\n",
    ")\n",
    "query_seq_frame.to_hdf('/tf/primo/data/queries/feature_seqs.h5', key='df', mode='w')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encode Target Set\n",
    "This code loads the target set's features, encodes them to DNA sequences, calculates distances to each query, and saves the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HDF5 outputs for the target set. Both stores are opened with mode='w', so any\n",
    "# existing file at these paths is replaced.\n",
    "dist_path = '/tf/primo/data/targets/query_target_dists.h5'  # query-target distances\n",
    "seq_path = '/tf/primo/data/targets/feature_seqs.h5'  # DNA sequence encodings of the targets\n",
    "\n",
    "# Target images are split up across 16 files, one per lower-case hexadecimal\n",
    "# digit ('0'-'f'). The files are too large to all be held in memory on a single\n",
    "# machine, so each one is processed on its own and its intermediates are\n",
    "# deleted before the next file is read.\n",
    "prefixes = ['%x' % i for i in range(16)]\n",
    "\n",
    "# The context managers close both stores on exit, even if an error is raised\n",
    "# part-way through (including a failure while opening the second store).\n",
    "with pd.HDFStore(dist_path, complevel=9, mode='w') as dist_store, pd.HDFStore(seq_path, complevel=9, mode='w') as seq_store:\n",
    "    for prefix in tqdm(prefixes):\n",
    "        target_features = pd.read_hdf('/tf/open_images/targets/features/targets_%s.h5' % prefix)\n",
    "\n",
    "        # Extract the target feature array once per file rather than once per query.\n",
    "        target_values = target_features.values\n",
    "\n",
    "        # Maps each query id to its Euclidean distance from every target in this file.\n",
    "        distances = {}\n",
    "        for query_id, query in query_features.iterrows():\n",
    "            distances[query_id] = np.sqrt(np.square(target_values - query.values).sum(1))\n",
    "\n",
    "        # One row per target, one column per query.\n",
    "        df = pd.DataFrame(distances, index=target_features.index)\n",
    "        dist_store.append('df', df)\n",
    "\n",
    "        # Low-level memory management: drop the large intermediates before encoding.\n",
    "        del df, distances, target_values\n",
    "\n",
    "        target_seqs = encoder.encode_feature_seqs(target_features)\n",
    "        df = pd.DataFrame(target_seqs, index=target_features.index, columns=['FeatureSequence'])\n",
    "        seq_store.append('df', df)\n",
    "        del df, target_seqs\n",
    "\n",
    "        # Free this file's features before the next file is loaded.\n",
    "        del target_features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encode Extended Target Set\n",
    "This code loads the extended target set's features, encodes them to DNA sequences, calculates distances to each query, and saves the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Images from the full set of ~9 million images that are not used in either the\n",
    "# \"target\" set or the \"train\" set. These are not hosted by the CVDF.\n",
    "\n",
    "# HDF5 outputs for the extended target set. Both stores are opened with\n",
    "# mode='w', so any existing file at these paths is replaced.\n",
    "dist_path = '/tf/primo/data/extended_targets/query_target_dists.h5'  # query-target distances\n",
    "seq_path = '/tf/primo/data/extended_targets/feature_seqs.h5'  # DNA sequence encodings of the targets\n",
    "\n",
    "# Prefixes are single lower-case hexadecimal digits ('0'-'f'), one per input file.\n",
    "# NOTE(review): an earlier comment gave '4b' as an example prefix, which this\n",
    "# range never produces -- confirm that 16 single-digit files is the complete set.\n",
    "prefixes = ['%x' % i for i in range(16)]\n",
    "\n",
    "# The context managers close both stores on exit, even if an error is raised\n",
    "# part-way through (including a failure while opening the second store).\n",
    "with pd.HDFStore(dist_path, complevel=9, mode='w') as dist_store, pd.HDFStore(seq_path, complevel=9, mode='w') as seq_store:\n",
    "    for prefix in tqdm(prefixes):\n",
    "        target_features = pd.read_hdf('/tf/open_images/extended_targets/features/extended_targets_%s.h5' % prefix)\n",
    "\n",
    "        # Extract the target feature array once per file rather than once per query.\n",
    "        target_values = target_features.values\n",
    "\n",
    "        # Maps each query id to its Euclidean distance from every target in this file.\n",
    "        distances = {}\n",
    "        for query_id, query in query_features.iterrows():\n",
    "            distances[query_id] = np.sqrt(np.square(target_values - query.values).sum(1))\n",
    "\n",
    "        # One row per target, one column per query.\n",
    "        df = pd.DataFrame(distances, index=target_features.index)\n",
    "        dist_store.append('df', df)\n",
    "\n",
    "        # Drop the large intermediates before encoding.\n",
    "        del df, distances, target_values\n",
    "\n",
    "        target_seqs = encoder.encode_feature_seqs(target_features)\n",
    "        df = pd.DataFrame(target_seqs, index=target_features.index, columns=['FeatureSequence'])\n",
    "        seq_store.append('df', df)\n",
    "        del df, target_seqs\n",
    "\n",
    "        # Free this file's features before the next file is loaded.\n",
    "        del target_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}